import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
import torch.optim as optim
from torch.utils.data import DataLoader

from modelnet import ModelNet40
from .models import SetTransformer
from lssot import LSSOT

import torch

def generate_uniform_points_on_sphere_torch(N = 1024):
    """
    Build an (N, 3) tensor of points spread evenly over the unit sphere.

    The points follow a Fibonacci (golden-ratio) lattice: heights are spaced
    uniformly in the open interval (-1, 1) and the azimuth of point k is
    2*pi*(k + 0.5) / golden_ratio.

    Parameters:
    N (int): How many points to place on the sphere.

    Returns:
    torch.Tensor: float32 tensor of shape (N, 3); row k holds the (x, y, z)
    coordinates of point k.
    """
    golden = (torch.sqrt(torch.tensor(5.0)) + 1) / 2

    # Half-integer offsets keep every point strictly away from the poles.
    offsets = 0.5 + torch.arange(0, N, dtype=torch.float32)

    heights = 1 - 2 * offsets / N
    azimuths = 2 * torch.pi * offsets / golden

    # Radius of the horizontal circle of latitude at each height.
    ring_radius = torch.sqrt(1 - heights**2)

    return torch.stack(
        (ring_radius * torch.cos(azimuths), ring_radius * torch.sin(azimuths), heights),
        dim=1,
    )



# Define dataset and dataloader
def get_dataloaders(num_points, batch_size):
    """
    Build the ModelNet40 data loaders used for training and evaluation.

    Parameters:
    num_points (int): Number of points per cloud, forwarded to ``ModelNet40``.
    batch_size (int): Batch size for the train / valid / test loaders.

    Returns:
    tuple: ``(train_loader, valid_loader, test_loader, mini_data_loader)``.
    ``mini_data_loader`` yields only the first training sample, with batch
    size 1 (useful for quick sanity / overfitting checks).
    """
    train_dataset = ModelNet40(num_points=num_points, partition='train')
    valid_dataset = ModelNet40(num_points=num_points, partition='valid')
    test_dataset = ModelNet40(num_points=num_points, partition='test')
    # Single-sample subset of the training data.
    mini_data = Subset(train_dataset, [0])

    # drop_last is kept only on the (shuffled) training loader. On the
    # evaluation loaders it would silently skip up to batch_size - 1 samples,
    # biasing the reported losses, and would produce an empty loader when a
    # split is smaller than batch_size.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    mini_data_loader = DataLoader(mini_data, batch_size=1, shuffle=False)

    return train_loader, valid_loader, test_loader, mini_data_loader

def train_autoencoder(train_loader, valid_loader, model, criterion, optimizer, device, num_epochs=50, save_interval=10, model_path='checkpoints/best_autoencoder.pth', coeff=0.0001):
    """
    Train an autoencoder with a reconstruction loss plus an LSSOT term on its latent output.

    For every batch, ``model(pointclouds)`` must return an indexable result
    whose element 0 is the latent batch and element 1 is the reconstruction.
    The optimised loss is ``criterion(reconstruction, pointclouds)`` plus
    ``coeff`` times the LSSOT value of every sample in the latent batch.

    Every ``save_interval`` epochs the model is validated with
    ``validate_autoencoder``. It then writes an interval checkpoint to
    ``checkpoints/autoencoder5_epoch_{epoch}.pth``, saves the weights to
    ``model_path`` when the validation loss improves, and calls
    ``plot_losses``.

    Parameters:
    train_loader: Iterable with ``len`` yielding ``(pointclouds, labels)`` batches.
    valid_loader: Loader passed to ``validate_autoencoder``.
    model (torch.nn.Module): Autoencoder to train (moved to ``device``).
    criterion: Reconstruction loss, called as ``criterion(recon, target)``.
    optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
    device (torch.device): Device for the model and data.
    num_epochs (int): Number of training epochs.
    save_interval (int): Validate / checkpoint / plot every this many epochs.
    model_path (str): Where the best-validation weights are saved.
    coeff (float): Weight of the LSSOT regularisation term.

    Returns:
    torch.nn.Module: The trained model.

    Raises:
    ValueError: If ``train_loader`` yields no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches; cannot compute an epoch loss")

    best_loss = float('inf')
    model.to(device)
    train_losses = []
    valid_losses = []

    # Uniform weights for the LSSOT term never change, so build them once.
    # NOTE(review): the size 1024 is hard-coded; it presumably has to match the
    # number of entries in each latent sample and LSSOT's ref_size -- confirm.
    weights = (torch.ones(1024) / 1024).to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for pointclouds, _ in train_loader:
            pointclouds = pointclouds.to(device).float()

            optimizer.zero_grad()

            res = model(pointclouds)
            latent = res[0]
            reconstructed_pointclouds = res[1]

            loss = criterion(reconstructed_pointclouds, pointclouds)
            for pt in latent:
                # A fresh LSSOT object is built per sample, exactly as before.
                # NOTE(review): if LSSOT draws its projections at construction
                # time, this re-samples them on every call; only hoist it out
                # of the loop if fixed projections are intended.
                loss = loss + coeff * LSSOT(num_projections=1000, ref_size=1024, device=device)(pt, weights)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)

        print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}')

        if (epoch + 1) % save_interval == 0:
            valid_loss = validate_autoencoder(valid_loader, model, criterion, device)
            valid_losses.append(valid_loss)

            interval_model_path = f'checkpoints/autoencoder5_epoch_{epoch + 1}.pth'
            # torch.save does not create missing directories.
            os.makedirs(os.path.dirname(interval_model_path), exist_ok=True)
            torch.save(model.state_dict(), interval_model_path)
            print(f'Saved model at epoch {epoch + 1}')

            if valid_loss < best_loss:
                best_loss = valid_loss
                model_dir = os.path.dirname(model_path)
                if model_dir:  # a bare filename has no directory to create
                    os.makedirs(model_dir, exist_ok=True)
                torch.save(model.state_dict(), model_path)
                print(f'Saved best model with loss: {best_loss:.4f}')

            plot_losses(train_losses, valid_losses, epoch + 1)

    return model

def plot_losses(train_losses, valid_losses, epoch):
    """
    Plot the per-epoch training loss and save it to ``loss/loss_plot_epoch_{epoch}.png``.

    Parameters:
    train_losses (list[float]): One training loss per completed epoch.
    valid_losses (list[float]): Validation losses. They are only printed, not
        plotted. In this file they are appended once per save interval, so they
        do not share the per-epoch x axis of ``train_losses``.
    epoch (int): Epoch number, used in the output filename.
    """
    fig = plt.figure()
    try:
        # Derive the x axis from the data so a length mismatch with `epoch`
        # cannot make plt.plot fail.
        plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training Loss')
        print('validation losses:', valid_losses)
        # savefig does not create missing directories.
        os.makedirs('loss', exist_ok=True)
        plt.savefig(f'loss/loss_plot_epoch_{epoch}.png')
    finally:
        # Always release the figure, even if saving fails, so repeated calls
        # during training do not accumulate open figures.
        plt.close(fig)

def validate_autoencoder(valid_loader, model, criterion, device):
    """
    Compute the mean reconstruction loss of ``model`` over ``valid_loader``.

    The model is switched to eval mode, and gradients are disabled for the
    pass. Only the reconstruction, element 1 of ``model(pointclouds)``, is
    scored; no regularisation term is added.

    Parameters:
    valid_loader: Iterable yielding ``(pointclouds, labels)`` batches.
    model (torch.nn.Module): Autoencoder to evaluate; left in eval mode.
    criterion: Loss called as ``criterion(reconstruction, pointclouds)``.
    device (torch.device): Device the batches are moved to.

    Returns:
    float: Average of the per-batch criterion values.

    Raises:
    ValueError: If ``valid_loader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for pointclouds, _ in valid_loader:
            pointclouds = pointclouds.to(device).float()
            reconstructed_pointclouds = model(pointclouds)[1]
            loss = criterion(reconstructed_pointclouds, pointclouds)
            running_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("valid_loader yields no batches; cannot compute a validation loss")

    return running_loss / num_batches

def load_model(model, path):
    """
    Load a saved ``state_dict`` from ``path`` into ``model`` if the file exists.

    Parameters:
    model (torch.nn.Module): Module whose parameters are overwritten in place.
    path (str): Checkpoint written with ``torch.save(model.state_dict(), ...)``.

    Returns:
    torch.nn.Module: The same ``model`` object, loaded or untouched. A missing
    file is reported with a message and is not treated as an error.
    """
    if not os.path.exists(path):
        print(f"Model file {path} does not exist.")
        return model

    # map_location='cpu' lets checkpoints saved on a GPU (e.g. cuda:1) load on
    # machines without that device. load_state_dict then copies the tensors
    # into the model's existing parameters on their current device.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    print(f"Model loaded from {path}")
    return model

if __name__ == '__main__':
    num_points = 1024
    batch_size = 32
    num_epochs = 100
    save_interval = 10
    # Prefer the second GPU as before, but fall back to the first one on
    # single-GPU machines instead of failing on a non-existent cuda:1.
    if torch.cuda.is_available():
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")
    model_path = 'checkpoints/settrans_nonlinear_entropy_lssot.pth'
    load_pretrained = True  # Set this flag to True to load a pre-trained model

    torch.manual_seed(3407)
    # NOTE(review): m, sigma and B_matrix are not used anywhere below. The
    # random draw is kept only so the seeded RNG stream seen by later random
    # operations stays the same as before; delete it if that does not matter.
    m = 512
    sigma = 2 ** 4
    B_matrix = torch.randn((m, 3), device=device) * sigma

    # Training writes checkpoints and loss plots into these directories.
    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('loss', exist_ok=True)

    train_loader, valid_loader, test_loader, mini = get_dataloaders(num_points, batch_size)
    autoencoder = SetTransformer(dim_input=3, num_outputs=1024, dim_output=3)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(autoencoder.parameters(), lr=1e-6)

    if load_pretrained:
        # load_model reports itself whether the checkpoint was found; training
        # proceeds from the loaded or freshly initialised weights either way.
        autoencoder = load_model(autoencoder, model_path)

    autoencoder = train_autoencoder(train_loader, valid_loader, autoencoder, criterion, optimizer, device, num_epochs, save_interval, model_path)